{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1fd97fa5-35c5-4318-94f8-2be0e04b8ed8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import math\n",
    "from transformers import BertTokenizer, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f51feb63-a6d4-409a-be42-0fb7df600667",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda')"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43677da2-2850-420f-8270-1bf46db981ad",
   "metadata": {},
   "source": [
    "### 加载训练好的本地模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ada9af59-0e5a-4e55-8cae-f8cbff7b112f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\PC\\AppData\\Local\\Temp\\ipykernel_27116\\44386423.py:140: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model.load_state_dict(torch.load(model_path))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TransformerClassifier(\n",
       "  (token_embedding): Embedding(21128, 256)\n",
       "  (positional_embedding): Embedding(128, 256)\n",
       "  (transformer_encoder): TransformerEncoder(\n",
       "    (layers): ModuleList(\n",
       "      (0-3): 4 x TransformerEncoderLayer(\n",
       "        (self_attn): MultiheadAttention(\n",
       "          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n",
       "        )\n",
       "        (linear1): Linear(in_features=256, out_features=512, bias=True)\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "        (linear2): Linear(in_features=512, out_features=256, bias=True)\n",
       "        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "        (dropout1): Dropout(p=0.1, inplace=False)\n",
       "        (dropout2): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "    )\n",
       "    (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (dropout): Dropout(p=0.1, inplace=False)\n",
       "  (fc_out): Linear(in_features=256, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class TransformerClassifier(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_dim, num_heads, num_encoder_layers, ff_dim, num_classes,\n",
    "                     max_len=512, dropout_rate=0.1):\n",
    "        \"\"\"\n",
    "        初始化 Transformer 分类器\n",
    "        \n",
    "        Args:\n",
    "            vocab_size (int): 词汇表大小（tokenizer.vocab_size）。\n",
    "            embed_dim (int): 词嵌入和 Transformer 的维度（d_model）。\n",
    "            num_heads (int): 多头注意力机制的头数，必须能整除 embed_dim。\n",
    "            num_encoder_layers (int): Transformer Encoder 的层数。\n",
    "            ff_dim (int): 前馈网络中间层的维度（通常为 embed_dim 的 2-4 倍）。\n",
    "            num_classes (int): 分类任务的类别数（2 表示正/负）。\n",
    "            max_len (int): 最大序列长度，用于位置嵌入。\n",
    "            dropout_rate (float): Dropout 比率，用于正则化。\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.embed_dim = embed_dim\n",
    "        # 词嵌入层，将 token ID 映射到 embed_dim 维向量\n",
    "        self.token_embedding = nn.Embedding(vocab_size, embed_dim)\n",
    "        # 可学习的位置嵌入，为每个位置生成 embed_dim 维向量\n",
    "        self.positional_embedding = nn.Embedding(max_len, embed_dim)\n",
    "\n",
    "        # 定义单个 Transformer Encoder 层\n",
    "        encoder_layer = nn.TransformerEncoderLayer(\n",
    "            d_model=embed_dim,              # 模型维度\n",
    "            nhead=num_heads,                # 注意力头数\n",
    "            dim_feedforward=ff_dim,         # 前馈网络中间层维度\n",
    "            dropout=dropout_rate,            # Dropout 比率\n",
    "            batch_first=True,               # 输入/输出形状为 (batch, seq, feature)，适配常见数据格式\n",
    "            activation='gelu'               # 使用 GELU 激活函数，相比 ReLU 更平滑，有助于梯度流动\n",
    "        )\n",
    "\n",
    "        # 堆叠多个 Transformer Encoder 层\n",
    "        self.transformer_encoder = nn.TransformerEncoder(\n",
    "            encoder_layer,\n",
    "            num_layers=num_encoder_layers,\n",
    "            norm=nn.LayerNorm(embed_dim)    # 显式添加 LayerNorm，规范化输出\n",
    "        )\n",
    "\n",
    "        self.dropout = nn.Dropout(dropout_rate)\n",
    "        \n",
    "        # 分类头：将 [CLS] token 的输出（embed_dim 维）映射到 num_classes 维\n",
    "        self.fc_out = nn.Linear(embed_dim, num_classes)\n",
    "\n",
    "        self.max_len = max_len # 存储最大序列长度，供位置编码使用\n",
    "\n",
    "        self._init_weights()\n",
    "\n",
    "    def _init_weights(self):\n",
    "        \"\"\"\n",
    "        初始化模型权重，使用 Xavier Uniform 初始化，适合 Transformer 模型。\n",
    "        避免初始权重过大或过小，加速收敛。\n",
    "        \"\"\"\n",
    "        for p in self.parameters():\n",
    "            if p.dim() > 1:  # 仅对二维以上参数（如线性层、嵌入层）应用\n",
    "                nn.init.xavier_uniform_(p)\n",
    "            # 对嵌入层可额外应用正态初始化\n",
    "            elif p.dim() == 2 and 'embedding' in p.name:\n",
    "                nn.init.normal_(p, mean=0.0, std=0.02)\n",
    "\n",
    "    def forward(self, input_ids, attention_mask):\n",
    "        \"\"\"\n",
    "        前向传播，处理输入序列并输出分类 logits。\n",
    "\n",
    "        Args:\n",
    "            input_ids (torch.Tensor): 形状 (batch_size, seq_len)，词的 ID。\n",
    "            attention_mask (torch.Tensor): 形状 (batch_size, seq_len)，1 表示有效 token，0 表示 padding。\n",
    "\n",
    "        Returns:\n",
    "            torch.Tensor: 形状 (batch_size, num_classes)，分类 logits。\n",
    "        \"\"\"\n",
    "        seq_len = input_ids.size(1)  # 获取序列长度\n",
    "\n",
    "        # 1. 词嵌入\n",
    "        token_embeds = self.token_embedding(input_ids)  # (batch_size, seq_len, embed_dim)\n",
    "        token_embeds = token_embeds * math.sqrt(self.embed_dim) # 缩放嵌入，稳定训练\n",
    "\n",
    "        # 2. 位置编码\n",
    "        # 生成位置索引：(batch_size, seq_len)，每个样本重复 0 到 seq_len-1\n",
    "        positions = torch.arange(0, seq_len, device=input_ids.device).unsqueeze(0).repeat(input_ids.size(0), 1)\n",
    "        position_embeds = self.positional_embedding(positions)  # (batch_size, seq_len, embed_dim)\n",
    "        \n",
    "        # 词嵌入与位置嵌入相加\n",
    "        x = token_embeds + position_embeds\n",
    "        x = self.dropout(x)  # 在嵌入后应用 Dropout，增强鲁棒性\n",
    "\n",
    "        # Transformer Encoder需要 src_key_padding_mask\n",
    "        # attention_mask: 1是token, 0是padding.\n",
    "        # src_key_padding_mask: True表示该位置是padding, 需要被mask掉.\n",
    "        src_key_padding_mask = (attention_mask == 0)  # (batch_size, seq_len)\n",
    "\n",
    "        # 3. Transformer Encoder\n",
    "        # 输入形状: (batch_size, seq_len, embed_dim)\n",
    "        encoder_output = self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)\n",
    "        # encoder_output shape: (batch_size, seq_len, embed_dim)\n",
    "\n",
    "        # 4. 分类\n",
    "        # 通常使用第一个token ([CLS] token)的输出来进行分类\n",
    "        cls_output = encoder_output[:, 0, :]  # (batch_size, embed_dim)\n",
    "        # 或者，可以对所有token的输出进行平均池化或最大池化\n",
    "        # cls_output = encoder_output.mean(dim=1) # 平均池化\n",
    "\n",
    "        cls_output = self.dropout(cls_output)\n",
    "        logits = self.fc_out(cls_output)  # (batch_size, num_classes)\n",
    "\n",
    "        return logits\n",
    "\n",
    "\n",
    "# 加载Bert 的分词器\n",
    "tokenizer_path = '../models/3_Transformer_Sentiment_Classification/bert-base-chinese'\n",
    "tokenizer = BertTokenizer.from_pretrained(tokenizer_path)\n",
    "    \n",
    "# 定义模型超参数\n",
    "VOCAB_SIZE = tokenizer.vocab_size  # 从之前加载的 BERT 分词器获取\n",
    "EMBED_DIM = 256                   # 嵌入维度，较小以减少计算量（BERT 常用 768）\n",
    "NUM_HEADS = 8                     # 多头注意力头数，需满足 embed_dim % num_heads == 0\n",
    "NUM_ENCODER_LAYERS = 4            # Encoder 层数，平衡性能与计算成本\n",
    "FF_DIM = 512                      # 前馈网络中间层维度，通常为 embed_dim 的 2-4 倍\n",
    "NUM_CLASSES = 2                   # 分类任务的类别数（正/负情感）\n",
    "DROPOUT_RATE = 0.1                # Dropout 比率，防止过拟合\n",
    "MAX_LEN = 128\n",
    "\n",
    "\n",
    "model = TransformerClassifier(\n",
    "    vocab_size=VOCAB_SIZE,\n",
    "    embed_dim=EMBED_DIM,\n",
    "    num_heads=NUM_HEADS,\n",
    "    num_encoder_layers=NUM_ENCODER_LAYERS,\n",
    "    ff_dim=FF_DIM,\n",
    "    num_classes=NUM_CLASSES,\n",
    "    max_len=MAX_LEN, # 从之前的配置中获取\n",
    "    dropout_rate=DROPOUT_RATE\n",
    ")\n",
    "model = model.to(device)\n",
    "\n",
    "model_path = '../models/3_Transformer_Sentiment_Classification/model_weights.pth'\n",
    "\n",
    "# 加载模型参数\n",
    "model.load_state_dict(torch.load(model_path))\n",
    "\n",
    "# 将模型设置为评估模式\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20cec762-f39c-42d0-b401-b0c7c030b5e4",
   "metadata": {},
   "source": [
    "## 模型推理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c8a318aa-aeb7-49c8-b4f7-34b62a872256",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_sentiment(text, max_len=MAX_LEN):\n",
    "    #  文本预处理\n",
    "    encoding = tokenizer.encode_plus(\n",
    "        text,\n",
    "        add_special_tokens=True,  # 添加 [CLS] 和 [SEP]\n",
    "        max_length=max_len,\n",
    "        padding='max_length',     # 填充到 max_len\n",
    "        truncation=True,          # 截断超长文本\n",
    "        return_tensors='pt',      # 返回 PyTorch 张量\n",
    "        return_attention_mask=True\n",
    "    )\n",
    "    input_ids = encoding['input_ids'].to(device)          # (1, max_len)\n",
    "    attention_mask = encoding['attention_mask'].to(device)  # (1, max_len)\n",
    "    # 模型推理\n",
    "    with torch.no_grad():  # 禁用梯度计算，节省内存\n",
    "        logits = model(input_ids, attention_mask)  # (1, num_classes)\n",
    "        pred = torch.argmax(logits, dim=-1).item()  # 预测类别 (0 或 1)\n",
    "\n",
    "    # 转换标签\n",
    "    return '正面' if pred == 1 else '负面'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1fa98604-c3a9-47d3-a8be-7d11bcafb8e7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Review : 虽没有第一部“我命由我不由天”的惊艳金句，但更多了些“怎能不知道这世间的规则，由谁所定？”的结构性思考，无量仙翁的“个体失范代替制度失范”真是最佳切口。狠狠期待第三部！\n",
      "sentiment : 正面\n",
      "Review : 我觉得中国这些人拍点电影，啥时候变成这种短视频短剧形式的切片合集了？一点点深度也没有了？太快餐了\n",
      "sentiment : 负面\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\develop\\anaconda\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\modules\\transformer.py:409: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at C:\\cb\\pytorch_1000000000000\\work\\aten\\src\\ATen\\NestedTensorImpl.cpp:180.)\n",
      "  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)\n",
      "D:\\develop\\anaconda\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\modules\\transformer.py:720: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at C:\\cb\\pytorch_1000000000000\\work\\aten\\src\\ATen\\native\\transformers\\cuda\\sdp_utils.cpp:555.)\n",
      "  return torch._transformer_encoder_layer_fwd(\n"
     ]
    }
   ],
   "source": [
    "pos_review = \"虽没有第一部“我命由我不由天”的惊艳金句，但更多了些“怎能不知道这世间的规则，由谁所定？”的结构性思考，无量仙翁的“个体失范代替制度失范”真是最佳切口。狠狠期待第三部！\"\n",
    "neg_review = \"我觉得中国这些人拍点电影，啥时候变成这种短视频短剧形式的切片合集了？一点点深度也没有了？太快餐了\"\n",
    "print(f'Review : {pos_review}\\nsentiment : {predict_sentiment(pos_review)}')\n",
    "print(f'Review : {neg_review}\\nsentiment : {predict_sentiment(neg_review)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "592d8098-f44e-4d6d-8a46-c79300734aac",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
